package com.za.plugin.transfer.form;


import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLExpr;
import com.alibaba.druid.sql.ast.SQLObject;
import com.alibaba.druid.sql.ast.expr.SQLMethodInvokeExpr;
import com.alibaba.druid.sql.ast.statement.SQLSelectItem;
import com.alibaba.druid.sql.ast.statement.SQLSelectQuery;
import com.alibaba.druid.sql.ast.statement.SQLSelectStatement;
import com.alibaba.druid.sql.ast.statement.SQLSubqueryTableSource;
import com.alibaba.druid.sql.dialect.mysql.ast.statement.MySqlSelectQueryBlock;
import com.alibaba.druid.sql.dialect.mysql.parser.MySqlStatementParser;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlASTVisitorAdapter;

import java.lang.reflect.Field;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Expands select-list aliases that are referenced inside {@code IF(...)} calls.
 *
 * <p>Aliases are not supported inside {@code IF} either, so every reference to an alias declared
 * in the same select list is replaced by the aliased expression. This is only a partial solution
 * and cannot solve the problem completely: substitution works on the rendered text of each
 * three-argument {@code IF} call. For example, it cannot tell an alias apart from a column that
 * has the same name.
 */
public class IfAliasFormTransfer implements FormTransfer {

    /** Finds an {@code if(} call as a whole word, in any case, with optional whitespace before the parenthesis. */
    private static final Pattern IF_CALL = Pattern.compile("(?i)\\bif\\s*\\(");

    /**
     * Rendered expressions matching this pattern can be spliced into another expression without
     * extra parentheses. These are plain (possibly qualified or quoted) names and numbers.
     */
    private static final Pattern SIMPLE_TERM = Pattern.compile("[\\w.`\"]+");

    /**
     * Reports whether {@code sql} contains an {@code if(} call.
     *
     * @param sql the SQL text; may be {@code null}
     * @return {@code true} if an {@code if(} call appears as a whole word, in any case, with
     *     optional whitespace before the parenthesis
     */
    @Override
    public boolean isSupport(String sql) {
        return sql != null && IF_CALL.matcher(sql).find();
    }

    /**
     * Rewrites alias references inside {@code IF} calls of a SELECT statement.
     *
     * @param sql a single MySQL statement
     * @return the rewritten query rendered by {@link SQLUtils#toMySqlString}, or {@code sql}
     *     unchanged when it is not a SELECT statement
     */
    @Override
    public String transfer(String sql) {
        // The SQL is parsed as-is: the parser does not care about whitespace. Normalising commas
        // textually would also rewrite the contents of string literals.
        return truncateIfAlias(sql);
    }

    /** Parses {@code sql} and rewrites the {@code IF} calls of every query block it contains. */
    private static String truncateIfAlias(String sql) {
        MySqlStatementParser parser = new MySqlStatementParser(sql);
        SQLObject statement = parser.parseStatement();
        if (!(statement instanceof SQLSelectStatement)) {
            // Only SELECT statements have select-list aliases to expand.
            return sql;
        }
        SQLSelectQuery query = ((SQLSelectStatement) statement).getSelect().getQuery();
        query.accept(new MySqlASTVisitorAdapter() {
            // Visiting every query block covers the top-level SELECT, FROM/WHERE subqueries and
            // each branch of a UNION, without casting a subquery to a specific block type.
            @Override
            public boolean visit(MySqlSelectQueryBlock queryBlock) {
                processQueryBlock(queryBlock);
                return super.visit(queryBlock);
            }
        });
        return SQLUtils.toMySqlString(query);
    }

    /**
     * Rewrites every three-argument {@code IF} call that belongs to one of {@code queryBlock}'s
     * own select items. It uses the aliases declared in that same select list.
     */
    private static void processQueryBlock(MySqlSelectQueryBlock queryBlock) {
        List<SQLSelectItem> selectList = queryBlock.getSelectList();
        final Map<String, SQLExpr> aliasToExpr = new HashMap<>();
        for (SQLSelectItem item : selectList) {
            if (item.getAlias() != null) {
                aliasToExpr.put(item.getAlias().replace("\"", ""), item.getExpr());
            }
        }
        // Membership is identity-based so it does not depend on SQLSelectItem.equals/hashCode.
        // Those may be structural, and would then change as an IF inside the item is rewritten.
        final Set<SQLSelectItem> ownItems =
                Collections.newSetFromMap(new IdentityHashMap<SQLSelectItem, Boolean>());
        ownItems.addAll(selectList);

        queryBlock.accept(new MySqlASTVisitorAdapter() {
            @Override
            public boolean visit(SQLMethodInvokeExpr expr) {
                if ("if".equalsIgnoreCase(expr.getMethodName()) && expr.getParameters().size() == 3) {
                    SQLSelectItem owner = findSQLSelectItem(expr);
                    // Skip calls outside the select list (WHERE, HAVING, ...) and calls that
                    // belong to nested query blocks; those are handled with their own aliases.
                    if (owner != null && ownItems.contains(owner)) {
                        replaceAliases(expr, owner, aliasToExpr);
                    }
                }
                return super.visit(expr);
            }
        });
    }

    /**
     * Substitutes alias references in the arguments of {@code ifExpr}, in a single pass over its
     * rendered text.
     *
     * @param ifExpr the {@code IF} call to rewrite in place
     * @param owner the select item containing {@code ifExpr}; its own alias is never expanded
     * @param aliasToExpr aliases of the enclosing select list mapped to their expressions
     */
    private static void replaceAliases(SQLMethodInvokeExpr ifExpr, SQLSelectItem owner,
                                       Map<String, SQLExpr> aliasToExpr) {
        Map<String, String> replacements = new HashMap<>();
        StringBuilder alternatives = new StringBuilder();
        for (Map.Entry<String, SQLExpr> entry : aliasToExpr.entrySet()) {
            String alias = entry.getKey();
            // An empty key would match between any two delimiters. Expanding the owner's own alias
            // would splice the IF into itself, e.g. "if(status is null, 0, status) status".
            if (alias.isEmpty() || entry.getValue() == owner.getExpr()) {
                continue;
            }
            String rendered = SQLUtils.toMySqlString(entry.getValue());
            // Parenthesise composite expressions so operator precedence survives the splice.
            replacements.put(alias,
                    SIMPLE_TERM.matcher(rendered).matches() ? rendered : "(" + rendered + ")");
            if (alternatives.length() > 0) {
                alternatives.append('|');
            }
            alternatives.append(Pattern.quote(alias));
        }
        if (replacements.isEmpty()) {
            return;
        }

        String content = SQLUtils.toMySqlString(ifExpr);
        // Lookarounds do not consume the delimiters, so adjacent references ("t, t") are all
        // replaced. A following "(" is excluded so function names are never rewritten.
        Matcher matcher = Pattern.compile("(?<=[,()\\s])(" + alternatives + ")(?=[,)\\s])")
                .matcher(content);
        StringBuffer rewritten = new StringBuffer();
        boolean changed = false;
        while (matcher.find()) {
            // quoteReplacement: rendered SQL may contain '$' or '\'.
            matcher.appendReplacement(rewritten,
                    Matcher.quoteReplacement(replacements.get(matcher.group(1))));
            changed = true;
        }
        if (!changed) {
            return;
        }
        matcher.appendTail(rewritten);

        SQLMethodInvokeExpr reparsed = (SQLMethodInvokeExpr) SQLUtils.toSQLExpr(rewritten.toString());
        // Move the new arguments into the original node through its parameter list and re-parent
        // them. Later parent-chain walks (findSQLSelectItem) must then reach the real select item.
        List<SQLExpr> parameters = ifExpr.getParameters();
        parameters.clear();
        for (SQLExpr parameter : reparsed.getParameters()) {
            parameter.setParent(ifExpr);
            parameters.add(parameter);
        }
    }

    /**
     * Returns the nearest select item enclosing {@code expr} within its own query block.
     *
     * @return the enclosing {@link SQLSelectItem}, or {@code null} when {@code expr} is not part of
     *     a select item of the nearest query (e.g. it sits in WHERE, HAVING or GROUP BY)
     */
    private static SQLSelectItem findSQLSelectItem(SQLMethodInvokeExpr expr) {
        for (SQLObject node = expr.getParent(); node != null; node = node.getParent()) {
            if (node instanceof SQLSelectItem) {
                return (SQLSelectItem) node;
            }
            if (node instanceof SQLSelectQuery) {
                // Reached the owning query without passing through a select item.
                return null;
            }
        }
        return null;
    }

    /** Ad-hoc demo entry point: prints the rewrite of a sample query. */
    public static class Main2 {
        public static void main(String[] args) {
            String sql = "select a.t tt,if(c_tt>0, tt,tt_) cc from (select c.a t,c.g dd,if(t<0,t, tt)from c group by t having t > 0 ) a group by tt desc";
            System.out.println(truncateIfAlias(sql));
        }
    }
}
